from typing import Callable, Dict, Any, List, Type
import numpy as np
import time
from tqdm import tqdm
from hypersense.utils.metrics import HPOEvaluationMetrics
from hypersense.utils.visualizer import HPOVisualizer
from hypersense.strategy.greedy_important_first import GreedyImportantFirstStrategy as _GIFBase


class HyperSensePipeline:
    """Two-phase hyperparameter-optimization pipeline.

    Phase 1 (warm-up) runs ``initial_optimizer_class`` on a subset of
    ``full_dataset`` drawn by ``sampler_class``. Phase 2 runs on the full
    dataset and is driven by ``strategy_class``:

    * subclasses of ``GreedyImportantFirstStrategy`` drive the search
      themselves, calling back into an optimizer builder and an importance
      evaluator supplied by this pipeline;
    * any other strategy exports a group schedule that is optimized one group
      at a time, with earlier groups frozen at their best values.

    Every dataset row is split as ``row[:-1]`` (features) and ``row[-1]``
    (label).
    """

    def __init__(
        self,
        search_space: Dict[str, Any],
        full_dataset: List[Any],
        objective_fn: Callable[[Dict[str, Any], Any], float],
        test_fn: Callable[[Dict[str, Any]], float],
        sampler_class: Type,
        initial_optimizer_class: Type,
        whole_optimizer_class: Type,
        importance_analyzer_class: Type,
        strategy_class: Type,
        default_config: Dict[str, Any],
        seed: int = 42,
        best_known_optimum: float = 1.0,
        mode: str = "max",
    ):
        """Store the pipeline configuration; no work happens until :meth:`run`.

        Args:
            search_space: Hyperparameter name -> search-space definition, handed
                to the optimizer classes.
            full_dataset: Rows whose last element is the label.
            objective_fn: Called as ``objective_fn(config, (X, y))``; returns a score.
            test_fn: Called as ``test_fn(config)`` on the selected best config.
            sampler_class: Built as ``sampler_class(full_dataset, sample_size, seed=seed)``;
                its ``sample()`` returns the warm-up rows.
            initial_optimizer_class: Optimizer used in the warm-up phase.
            whole_optimizer_class: Optimizer used for full-data (sub)space searches.
            importance_analyzer_class: Provides ``fit(configs, scores)`` and ``explain()``.
            strategy_class: Strategy deciding the full-data search order.
            default_config: Baseline values for hyperparameters not currently searched.
            seed: Base random seed.
            best_known_optimum: Reference optimum forwarded to the metrics.
            mode: Best trials are chosen with ``min`` when ``"min"``, ``max`` otherwise.
        """
        # Store all initialization parameters
        self.search_space = search_space
        self.full_dataset = full_dataset
        self.objective_fn = objective_fn
        self.test_fn = test_fn
        self.sampler_class = sampler_class
        self.initial_optimizer_class = initial_optimizer_class
        self.whole_optimizer_class = whole_optimizer_class
        self.importance_analyzer_class = importance_analyzer_class
        self.strategy_class = strategy_class
        self.default_config = default_config
        self.seed = seed
        self.best_known_optimum = best_known_optimum
        self.mode = mode

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _select_best(self, trials):
        """Return the trial dict with the best ``"score"`` for ``self.mode``, or None if empty."""
        if not trials:
            return None
        chooser = min if self.mode == "min" else max
        return chooser(trials, key=lambda t: t["score"])

    def _compute_importance(self, configs, scores):
        """Fit a fresh importance analyzer on (configs, scores) and return ``explain()``."""
        try:
            analyzer = self.importance_analyzer_class(seed=self.seed)
        except TypeError:
            # The analyzer class does not accept a ``seed`` keyword.
            analyzer = self.importance_analyzer_class()
        analyzer.fit(configs, scores)
        return analyzer.explain()

    @staticmethod
    def _filter_history(history, subspace):
        """Project past trials onto ``subspace`` for optimizer warm-starting.

        Only trials whose config defines *every* key of ``subspace`` are kept.

        Returns:
            ``(points, scores)``: parallel lists of projected configs and their scores.
        """
        points, scores = [], []
        for trial in history:
            config = trial["config"]
            projected = {k: config[k] for k in subspace if k in config}
            if len(projected) == len(subspace):
                points.append(projected)
                scores.append(trial["score"])
        return points, scores

    # ------------------------------------------------------------------ #
    # Pipeline phases
    # ------------------------------------------------------------------ #
    def _run_warmup(self, sample_ratio, initial_trials, verbose, quiet):
        """Run the initial optimizer on a sampled subset; return the warm-up history."""
        sample_size = int(sample_ratio * len(self.full_dataset))
        sampler = self.sampler_class(self.full_dataset, sample_size, seed=self.seed)
        sampled_data = sampler.sample()

        # Split sampled rows into features and labels
        X_sampled = np.array([row[:-1] for row in sampled_data])
        y_sampled = np.array([row[-1] for row in sampled_data])

        if not quiet:
            print(f"🚀 Running initial optimization on {sample_ratio:.0%} subset...")
        progress = tqdm(
            total=initial_trials, desc="Initial Optimization", disable=(not verbose)
        )

        def initial_opt_objective(config):
            result = self.objective_fn(config, (X_sampled, y_sampled))
            if verbose:
                progress.update(1)
            return result

        initial_optimizer = self.initial_optimizer_class(
            space=self.search_space,
            objective_fn=initial_opt_objective,
            max_trials=initial_trials,
            seed=self.seed,
            verbose=verbose,
            mode=self.mode,
        )
        try:
            results = initial_optimizer.optimize()
        finally:
            # Close the bar even if the optimizer raises.
            progress.close()

        # Optimizer results are (config, result, elapsed_time) triples.
        return [
            {"config": config, "score": score, "elapsed_time": elapsed}
            for config, score, elapsed in results
        ]

    def _run_greedy_strategy(
        self,
        warmup_history,
        X,
        y,
        *,
        step_trials,
        total_trials,
        min_trials_for_importance,
        full_group_ratio,
        top_k,
        verbose,
        quiet,
    ):
        """Run a GreedyImportantFirstStrategy(-subclass); return (trial_history, best_trial).

        Sets ``self.schedule`` to the strategy's ``logs``. If the strategy
        raises, the error is printed, ``self.schedule`` becomes ``[]`` and the
        warm-up history is returned with no best trial (best-effort fallback).
        """
        if not quiet:
            print("🎯 Running GreedyImportantFirstStrategy HPO...")

        def optimizer_builder(subspace, history, fixed_config, max_trials):
            # Build a full-data optimizer for ``subspace`` with all other
            # hyperparameters pinned to ``fixed_config``, warm-started from history.
            points, rewards = self._filter_history(history, subspace)
            return self.whole_optimizer_class(
                space=subspace,
                objective_fn=lambda config: self.objective_fn(
                    {**fixed_config, **config}, (X, y)
                ),
                max_trials=max_trials,
                seed=self.seed + len(history),
                points_to_evaluate=points,
                evaluated_rewards=rewards,
                verbose=verbose,
                mode=self.mode,
            )

        strategy = self.strategy_class(
            search_space=self.search_space,
            initial_trials=warmup_history,
            optimizer_builder=optimizer_builder,
            importance_evaluator=self._compute_importance,
            step_trials=int(step_trials),
            max_total_trials=int(total_trials),
            min_trials_for_importance=int(min_trials_for_importance),
            default_config=self.default_config,
            full_group_ratio=full_group_ratio,
            top_k=top_k,
        )

        try:
            trial_history = strategy.run()
            self.schedule = strategy.logs
            best_trial = self._select_best(trial_history)
        except Exception as e:  # deliberate best-effort boundary
            print(f"[Pipeline] Strategy run failed: {e}")
            self.schedule = []
            return warmup_history, None
        return trial_history, best_trial

    def _run_grouped_strategy(
        self, warmup_history, X, y, *, total_trials, top_k, verbose, quiet
    ):
        """Optimize the strategy's group schedule sequentially; return (trial_history, best_trial).

        Each group is searched on the full data while every other hyperparameter
        is fixed to ``default_config`` overridden by earlier groups' winners.
        Annotates each schedule stage with timing and its best score. The
        returned ``best_trial`` is the winner of the last group that produced
        trials; its config already carries every earlier group's winning values.
        """
        if self.strategy_class.__name__ == "SequentialGroupingStrategy":
            if not quiet:
                print("📊 Evaluating hyperparameter importance...")
            importance = self._compute_importance(
                [r["config"] for r in warmup_history],
                [r["score"] for r in warmup_history],
            )
            sorted_importance = sorted(
                importance.items(), key=lambda x: x[1], reverse=True
            )
            self.importance_weights = importance
            self.importance_ranks = {
                k: i + 1 for i, (k, _) in enumerate(sorted_importance)
            }
            if not quiet:
                print("📈 Hyperparameter Importance:")
                for k, v in sorted_importance:
                    print(f"  {k:<20}: {v:.4f} (Rank {self.importance_ranks[k]})")
        else:
            importance = None

        if not quiet:
            print("🔧 Building and running search strategy...")
        strategy = self.strategy_class(
            importance=importance, search_space=self.search_space, group_size=top_k
        )
        schedule = strategy.export_group_schedule(total_trials=total_trials)
        self.schedule = schedule
        if not quiet:
            print("🎯 Running full-data HPO by group...")

        def make_objective(base_config, bar):
            # Factory so each group's objective binds its own base config and bar.
            def group_objective(config):
                result = self.objective_fn({**base_config, **config}, (X, y))
                if verbose:
                    bar.update(1)
                return result

            return group_objective

        trial_history = []
        merged_config = {}
        best_trial = None

        for i, stage in enumerate(schedule):
            group = stage["group"]
            budget = stage["budget"]
            subspace = {k: self.search_space[k] for k in group}

            points, rewards = self._filter_history(
                warmup_history + trial_history, subspace
            )
            if not quiet:
                print(f"Group: {group} | Warm-start: {len(points)} points")
                print(f"Budget: {budget}")

            # Values held fixed while this group is searched.
            base_config = {**self.default_config, **merged_config}
            group_bar = tqdm(
                total=budget,
                desc=f"Group {i+1}/{len(schedule)}",
                disable=(not verbose),
            )
            optimizer = self.whole_optimizer_class(
                space=subspace,
                objective_fn=make_objective(base_config, group_bar),
                max_trials=budget,
                seed=self.seed + len(trial_history),
                points_to_evaluate=points,
                evaluated_rewards=rewards,
                verbose=verbose,
                mode=self.mode,
            )
            try:
                trials = optimizer.optimize()
            finally:
                group_bar.close()

            # Optimizer results are (config, result, elapsed_time) triples.
            trial_dicts = [
                {
                    "group_config": cfg,
                    "config": {**base_config, **cfg},
                    "score": result,
                    "elapsed_time": elapsed,
                }
                for cfg, result, elapsed in trials
            ]

            group_best = self._select_best(trial_dicts)
            if group_best is not None:
                # Freeze this group's winning values for all later groups.
                merged_config.update(group_best["group_config"])
                best_trial = group_best

            trial_history.extend(
                {
                    "config": t["config"],
                    "score": t["score"],
                    "elapsed_time": t["elapsed_time"],
                }
                for t in trial_dicts
            )

            elapsed_times = [t["elapsed_time"] for t in trial_dicts]
            total_group_time = sum(elapsed_times)
            stage["total_elapsed_time"] = total_group_time
            stage["avg_elapsed_time"] = (
                total_group_time / len(elapsed_times) if elapsed_times else 0
            )
            stage["best_score"] = group_best["score"] if group_best is not None else None

        return trial_history, best_trial

    def run(
        self,
        sample_ratio: float = 0.3,
        initial_trials: int = 40,
        total_trials: int = 100,
        step_trials: int = 20,
        min_trials_for_importance: int = 10,
        full_group_ratio: float = 0.2,
        verbose: bool = False,
        quiet: bool = True,
        top_k: int = 2,
    ) -> "HyperSensePipeline":
        """Run warm-up, then the configured full-data strategy, then collect metrics.

        Args:
            sample_ratio: Fraction of ``full_dataset`` used for the warm-up phase.
            initial_trials: Trial budget of the warm-up optimizer.
            total_trials: Total full-data trial budget given to the strategy.
            step_trials: Per-step budget (greedy strategies only).
            min_trials_for_importance: Forwarded to greedy strategies.
            full_group_ratio: Forwarded to greedy strategies.
            verbose: Show progress bars and optimizer verbosity.
            quiet: Suppress the pipeline's own progress messages.
            top_k: ``top_k`` for greedy strategies; ``group_size`` otherwise.

        Returns:
            ``self``, with ``trial_history``, ``warmup_history``, ``schedule``,
            ``elapsed_time``, ``test_score`` and ``metrics`` populated.
        """
        start_time = time.time()
        if not quiet:
            print("🔍 Starting HyperSense Pipeline...")

        warmup_history = self._run_warmup(sample_ratio, initial_trials, verbose, quiet)

        # === Decide which strategy to run ===
        self.strategy_name = self.strategy_class.__name__
        # Remembered so summary() formats the schedule with the same rule used here.
        self._uses_greedy_schedule = issubclass(self.strategy_class, _GIFBase)

        # Full dataset for the second phase
        X = np.array([row[:-1] for row in self.full_dataset])
        y = np.array([row[-1] for row in self.full_dataset])

        if self._uses_greedy_schedule:
            trial_history, best_trial = self._run_greedy_strategy(
                warmup_history,
                X,
                y,
                step_trials=step_trials,
                total_trials=total_trials,
                min_trials_for_importance=min_trials_for_importance,
                full_group_ratio=full_group_ratio,
                top_k=top_k,
                verbose=verbose,
                quiet=quiet,
            )
        else:
            trial_history, best_trial = self._run_grouped_strategy(
                warmup_history,
                X,
                y,
                total_trials=total_trials,
                top_k=top_k,
                verbose=verbose,
                quiet=quiet,
            )

        # --- Robust final reporting ---
        self.trial_history = trial_history
        self.warmup_history = warmup_history
        self.elapsed_time = round(time.time() - start_time, 2)
        self.test_score = self.test_fn(best_trial["config"]) if best_trial else None

        # --- Collect metrics ---
        self.metrics = HPOEvaluationMetrics(
            trial_history=self.trial_history,
            warmup_history=self.warmup_history,
            pipeline_elapsed_time=self.elapsed_time,
            test_score=self.test_score,
            best_known_optimum=self.best_known_optimum,
            mode=self.mode,
        )
        return self

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #
    def _plot_trials(self):
        """Plot the trial history over the first two search-space keys, if possible."""
        param_keys = list(self.search_space.keys())
        if len(param_keys) < 2:
            print("Not enough float hyperparameters to plot 3D surface.")
            return
        if not self.trial_history:
            # min()/max() below would fail on an empty history.
            print("No trials to visualize.")
            return

        scores = [trial["score"] for trial in self.trial_history]
        HPOVisualizer(self.trial_history).summary(
            param_x=param_keys[0],
            param_y=param_keys[1],
            z_range_truncate=True,
            z_range=(min(scores), max(scores)),
            regret_curve=self.metrics.calculate_regret_curve(),
        )

    def summary(
        self,
        show_metrics: bool = True,
        show_visualizations: bool = True,
        show_importance: bool = True,
    ):
        """Print importance weights, the optimization schedule, metrics and plots.

        Must be called after :meth:`run`. Returns ``self`` for chaining.
        """
        print("\n📊 Pipeline Summary")
        print("=" * 40)
        # Importance weights are only computed for SequentialGroupingStrategy.
        if (
            show_importance
            and getattr(self, "strategy_name", None) == "SequentialGroupingStrategy"
            and hasattr(self, "importance_weights")
        ):
            print("\n🔢 Importance Weights:")
            for k, v in sorted(self.importance_weights.items(), key=lambda x: -x[1]):
                print(f"  {k:<20}: {v:.4f} (Rank {self.importance_ranks[k]})")

        print(f"\n🧭 Search Strategy: {self.strategy_name}")
        print("📅 Optimization Schedule:")
        # Use the same greedy-vs-grouped decision that run() made (covers subclasses).
        uses_greedy = getattr(self, "_uses_greedy_schedule", None)
        if uses_greedy is None:
            uses_greedy = self.strategy_name == "GreedyImportantFirstStrategy"

        for i, s in enumerate(self.schedule):
            elapsed_info = ""
            if "total_elapsed_time" in s:
                elapsed_info = f", Total Time: {s['total_elapsed_time']:.2f}s, Avg Time: {s.get('avg_elapsed_time', 0):.2f}s"
            if uses_greedy:
                print(
                    f"  Round {i + 1}: Group: {s['group']}, Trials: {s['trials']}, Best Score: {s.get('best_score', None)}{elapsed_info}"
                )
            else:
                best_score_info = f", Best Score: {s.get('best_score', 'N/A')}"
                print(
                    f"  Group {i + 1}: {s['group']}, Budget: {s['budget']}{best_score_info}{elapsed_info}"
                )

        if show_metrics:
            for k, v in self.metrics.summary().items():
                print(f"{k}: {v}")

        if show_visualizations:
            self._plot_trials()

        print("=" * 40)
        return self
